# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch
import random
import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.common.utility_functions import torch_random

# Presumably the per-vector timeout (in seconds) read by the sweep framework — TODO confirm.
TIMEOUT = 10
# Seed Python's `random` module at import time so any use of it is reproducible.
# NOTE(review): only Python's `random` is seeded here; if torch_random draws from
# torch's RNG, that generator is not seeded by this line — confirm if it matters.
random.seed(0)

# Sweep parameter space, keyed by suite name. Each suite maps a run() argument
# name ("shape", "dtype", "layout") to the list of values to sweep; presumably the
# framework expands the cross product of these lists into test vectors — confirm.
parameters = {
    "nightly": {
        # 2-D input shapes (every entry has exactly two dimensions).
        "shape": [
            [1, 10],
            [1, 128],
            [1, 512],
            [1, 768],
            [10, 128],
            [1000, 1008],
            [1000, 1024],
            [1000, 1280],
            [1000, 1512],
            [1000, 1536],
            [1000, 1664],
            [1000, 1920],
            [1000, 2016],
            [1000, 2048],
            [1000, 2208],
            [1000, 2520],
            [1000, 3024],
            [1000, 3712],
            [1000, 400],
            [1000, 4096],
            [1000, 440],
            [1000, 512],
            [1000, 672],
            [1000, 7392],
            [1000, 768],
            [1000, 784],
            [1000, 888],
            [1000, 912],
            [1024, 1024],
            [1024, 128],
            [1024, 2048],
            [1024, 256],
            [1024, 4096],
            [1024, 512],
            [1024, 576],
            [10240, 1280],
            [1152, 384],
            [12, 3],
            [12, 512],
            [12, 64],
            [128, 1024],
            [128, 10],
            [128, 128],
            [128, 2048],
            [128, 32],
            [128, 4096],
            [128, 512],
            [128, 64],
            [128, 784],
            [128, 9216],
            [1280, 1280],
            [1280, 320],
            [1280, 5120],
            [1280, 768],
            [1280, 960],
            [1536, 1536],
            [1536, 384],
            [1536, 512],
            [1536, 6144],
            [1536, 768],
            [16, 512],
            [160, 160],
            [160, 640],
            [16384, 4096],
            [192, 192],
            [192, 384],
            [192, 768],
            [196, 384],
            [2, 1024],
            [2, 1],
            [2, 768],
            [2048, 128],
            [2048, 2048],
            [2048, 256],
            [2048, 512],
            [2048, 8192],
            [21843, 768],
            [2304, 768],
            [24, 512],
            [250002, 768],
            [250880, 1536],
            [256, 1024],
            [256, 1280],
            [256, 160],
            [256, 2048],
            [256, 256],
            [256, 32],
            [256, 512],
            [256, 64],
            [256, 768],
            [256, 80],
            [2560, 320],
            [256008, 1024],
            [288, 96],
            [3, 12],
            [3, 512],
            [30000, 128],
            [3072, 1024],
            [3072, 768],
            [3129, 1536],
            [32, 128],
            [32, 32],
            [32, 512],
            [320, 1280],
            [320, 320],
            [320, 768],
            [32128, 1024],
            [32128, 512],
            [32128, 768],
            [384, 128],
            [384, 1536],
            [384, 196],
            [384, 384],
            [384, 512],
            [384, 768],
            [384, 96],
            [3840, 1280],
            [4, 192],
            [4, 256],
            [4, 512],
            [4096, 1024],
            [4096, 128],
            [4096, 16384],
            [4096, 25088],
            [4096, 4096],
            [4608, 1536],
            [50257, 768],
            [50272, 512],
            [512, 1024],
            [512, 128],
            [512, 2048],
            [512, 256],
            [512, 2],
            [512, 384],
            [512, 512],
            [512, 768],
            [5120, 1280],
            [5120, 640],
            [51200, 1024],
            [51865, 768],
            [576, 192],
            [6, 512],
            [6144, 1536],
            [64, 128],
            [64, 12],
            [64, 256],
            [64, 64],
            [640, 1280],
            [640, 160],
            [640, 2560],
            [640, 640],
            [640, 768],
            [65024, 4544],
            [768, 1280],
            [768, 128],
            [768, 1536],
            [768, 192],
            [768, 256],
            [768, 3072],
            [768, 768],
            [784, 128],
            [8, 512],
            [8192, 2048],
            [92, 192],
            [92, 256],
            [9216, 128],
            [96, 384],
            [96, 96],
        ],
        # Only bfloat16 is swept in this suite.
        "dtype": [ttnn.bfloat16],
        # Both row-major and tiled layouts are swept.
        "layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT],
    }
}


# Invalidate vector is called during the generation phase where each vector will be passed in.
# If invalidated, the vector will still be stored but will be skipped.
# Returns (False, None) if the vector is valid, and (True, reason) with a string
# describing the reason for invalidation if it is invalid.
def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    """Decide whether a generated test vector should be skipped.

    Args:
        test_vector: Mapping with the sweep keys "shape", "dtype" and "layout".

    Returns:
        ``(False, None)`` for a valid vector, otherwise ``(True, reason)``.
    """
    # Every invalidation rule below concerns bfloat8_b; anything else is valid.
    if test_vector["dtype"] != ttnn.bfloat8_b:
        return False, None
    if test_vector["layout"] == ttnn.ROW_MAJOR_LAYOUT:
        return True, "bfloat8_b not supported with ROW_MAJOR_LAYOUT"
    # This sweep's vectors carry no "slice_specs" entry, so the rank requirement
    # is checked on the input shape itself (indexing "slice_specs" would raise
    # KeyError instead of invalidating the vector).
    if len(test_vector["shape"]) < 2:
        return True, "bfloat8_b not supported with dims  < 2"
    return False, None


def run(
    shape,
    dtype,
    layout,
    *,
    device,
):
    """Sweep entry point for transposing a tensor with ``.T``.

    Placeholder: ``.T`` is not supported on ttnn tensors yet, so every call
    raises until ``.T`` is bound to a transpose op.

    Args:
        shape: Input shape for the vector (one entry of the "shape" sweep list).
        dtype: ttnn dtype requested for the input tensor.
        layout: ttnn layout requested for the input tensor.
        device: Device the vector would run on (keyword-only).

    Raises:
        NotImplementedError: Always. It subclasses ``Exception``, so callers
            that catch ``Exception`` still handle it.
    """
    raise NotImplementedError(".T is not supported, TODO: bind to transpose")
